{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jVtJsW44SGf9"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/jeffheaton/t81_558_deep_learning/blob/master/t81_558_class_13_02_checkpoint.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Nm3HjQKdSGf9"
   },
   "source": [
    "# T81-558: Applications of Deep Neural Networks\n",
    "**Module 13: Advanced/Other Topics**\n",
    "* Instructor: [Jeff Heaton](https://sites.wustl.edu/jeffheaton/), McKelvey School of Engineering, [Washington University in St. Louis](https://engineering.wustl.edu/Programs/Pages/default.aspx)\n",
    "* For more information visit the [class website](https://sites.wustl.edu/jeffheaton/t81-558/)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KqrrqF9ESGf-"
   },
   "source": [
    "# Module 13 Video Material\n",
    "\n",
    "* Part 13.1: Flask and Deep Learning Web Services [[Video]](https://www.youtube.com/watch?v=H73m9XvKHug&list=PLjy4p-07OYzulelvJ5KVaT2pDlxivl_BN) [[Notebook]](https://github.com/jeffheaton/t81_558_deep_learning/blob/master/t81_558_class_13_01_flask.ipynb)\n",
    "* **Part 13.2: Interrupting and Continuing Training** [[Video]](https://www.youtube.com/watch?v=kaQCdv46OBA&list=PLjy4p-07OYzulelvJ5KVaT2pDlxivl_BN) [[Notebook]](https://github.com/jeffheaton/t81_558_deep_learning/blob/master/t81_558_class_13_02_checkpoint.ipynb)\n",
    "* Part 13.3: Using a Keras Deep Neural Network with a Web Application  [[Video]](https://www.youtube.com/watch?v=OBbw0e-UroI&list=PLjy4p-07OYzulelvJ5KVaT2pDlxivl_BN) [[Notebook]](https://github.com/jeffheaton/t81_558_deep_learning/blob/master/t81_558_class_13_03_web.ipynb)\n",
    "* Part 13.4: When to Retrain Your Neural Network [[Video]](https://www.youtube.com/watch?v=K2Tjdx_1v9g&list=PLjy4p-07OYzulelvJ5KVaT2pDlxivl_BN) [[Notebook]](https://github.com/jeffheaton/t81_558_deep_learning/blob/master/t81_558_class_13_04_retrain.ipynb)\n",
    "* Part 13.5: Tensor Processing Units (TPUs)  [[Video]](https://www.youtube.com/watch?v=Ygyf3NUqvSc&list=PLjy4p-07OYzulelvJ5KVaT2pDlxivl_BN) [[Notebook]](https://github.com/jeffheaton/t81_558_deep_learning/blob/master/t81_558_class_13_05_tpu.ipynb)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1g6YREggSGf-"
   },
   "source": [
    "## Google CoLab Instructions\n",
    "The following code ensures that Google CoLab is running the correct version of TensorFlow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NjPmakc1SGf_",
    "outputId": "74593f37-d653-4ca2-d927-a2071c42aa08"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: using Google CoLab\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    from google.colab import drive\n",
    "    COLAB = True\n",
    "    print(\"Note: using Google CoLab\")\n",
    "    %tensorflow_version 2.x\n",
    "except:\n",
    "    print(\"Note: not using Google CoLab\")\n",
    "    COLAB = False\n",
    "\n",
    "# Nicely formatted time string\n",
    "def hms_string(sec_elapsed):\n",
    "    h = int(sec_elapsed / (60 * 60))\n",
    "    m = int((sec_elapsed % (60 * 60)) / 60)\n",
    "    s = sec_elapsed % 60\n",
    "    return f\"{h}:{m:>02}:{s:>05.2f}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "52zYAendSGf_"
   },
   "source": [
    "# Part 13.2: Interrupting and Continuing Training\n",
    "\n",
    "We would train our Keras models in one pass in an ideal world, utilizing as much GPU and CPU power as we need. The world in which we train our models is anything but ideal. In this part, we will see that we can stop and continue and even adjust training at later times. We accomplish this continuation with checkpoints. We begin by creating several utility functions. The first utility generates an output directory that has a unique name. This technique allows us to organize multiple runs of our experiment. We provide the Logger class to route output to a log file contained in the output directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "vFp7XunVSGgA"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import sys\n",
    "import time\n",
    "import numpy as np\n",
    "from typing import Any, List, Tuple, Union\n",
    "from tensorflow.keras.datasets import mnist\n",
    "from tensorflow.keras import backend as K\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.callbacks import EarlyStopping, \\\n",
    "  LearningRateScheduler, ModelCheckpoint\n",
    "from tensorflow.keras import regularizers\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Dropout, Flatten\n",
    "from tensorflow.keras.layers import Conv2D, MaxPooling2D\n",
    "from tensorflow.keras.models import load_model\n",
    "import pickle\n",
    "\n",
    "def generate_output_dir(outdir, run_desc):\n",
    "    prev_run_dirs = []\n",
    "    if os.path.isdir(outdir):\n",
    "        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(\\\n",
    "            os.path.join(outdir, x))]\n",
    "    prev_run_ids = [re.match(r'^\\d+', x) for x in prev_run_dirs]\n",
    "    prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]\n",
    "    cur_run_id = max(prev_run_ids, default=-1) + 1\n",
    "    run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')\n",
    "    assert not os.path.exists(run_dir)\n",
    "    os.makedirs(run_dir)\n",
    "    return run_dir\n",
    "\n",
    "# From StyleGAN2\n",
    "class Logger(object):\n",
    "    \"\"\"Redirect stderr to stdout, optionally print stdout to a file, and \n",
    "    optionally force flushing on both stdout and the file.\"\"\"\n",
    "\n",
    "    def __init__(self, file_name: str = None, file_mode: str = \"w\", \\\n",
    "                 should_flush: bool = True):\n",
    "        self.file = None\n",
    "\n",
    "        if file_name is not None:\n",
    "            self.file = open(file_name, file_mode)\n",
    "\n",
    "        self.should_flush = should_flush\n",
    "        self.stdout = sys.stdout\n",
    "        self.stderr = sys.stderr\n",
    "\n",
    "        sys.stdout = self\n",
    "        sys.stderr = self\n",
    "\n",
    "    def __enter__(self) -> \"Logger\":\n",
    "        return self\n",
    "\n",
    "    def __exit__(self, exc_type: Any, exc_value: Any, \\\n",
    "                 traceback: Any) -> None:\n",
    "        self.close()\n",
    "\n",
    "    def write(self, text: str) -> None:\n",
    "        \"\"\"Write text to stdout (and a file) and optionally flush.\"\"\"\n",
    "        if len(text) == 0: \n",
    "            return\n",
    "\n",
    "        if self.file is not None:\n",
    "            self.file.write(text)\n",
    "\n",
    "        self.stdout.write(text)\n",
    "\n",
    "        if self.should_flush:\n",
    "            self.flush()\n",
    "\n",
    "    def flush(self) -> None:\n",
    "        \"\"\"Flush written text to both stdout and a file, if open.\"\"\"\n",
    "        if self.file is not None:\n",
    "            self.file.flush()\n",
    "\n",
    "        self.stdout.flush()\n",
    "\n",
    "    def close(self) -> None:\n",
    "        \"\"\"Flush, close possible files, and remove \n",
    "            stdout/stderr mirroring.\"\"\"\n",
    "        self.flush()\n",
    "\n",
    "        # if using multiple loggers, prevent closing in wrong order\n",
    "        if sys.stdout is self:\n",
    "            sys.stdout = self.stdout\n",
    "        if sys.stderr is self:\n",
    "            sys.stderr = self.stderr\n",
    "\n",
    "        if self.file is not None:\n",
    "            self.file.close()\n",
    "\n",
    "def obtain_data():\n",
    "    (x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "    print(\"Shape of x_train: {}\".format(x_train.shape))\n",
    "    print(\"Shape of y_train: {}\".format(y_train.shape))\n",
    "    print()\n",
    "    print(\"Shape of x_test: {}\".format(x_test.shape))\n",
    "    print(\"Shape of y_test: {}\".format(y_test.shape))\n",
    "\n",
    "    # input image dimensions\n",
    "    img_rows, img_cols = 28, 28\n",
    "    if K.image_data_format() == 'channels_first':\n",
    "        x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)\n",
    "        x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)\n",
    "        input_shape = (1, img_rows, img_cols)\n",
    "    else:\n",
    "        x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n",
    "        x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n",
    "        input_shape = (img_rows, img_cols, 1)\n",
    "    x_train = x_train.astype('float32')\n",
    "    x_test = x_test.astype('float32')\n",
    "    x_train /= 255\n",
    "    x_test /= 255\n",
    "    print('x_train shape:', x_train.shape)\n",
    "    print(\"Training samples: {}\".format(x_train.shape[0]))\n",
    "    print(\"Test samples: {}\".format(x_test.shape[0]))\n",
    "    # convert class vectors to binary class matrices\n",
    "    y_train = tf.keras.utils.to_categorical(y_train, num_classes)\n",
    "    y_test = tf.keras.utils.to_categorical(y_test, num_classes)\n",
    "    \n",
    "    return input_shape, x_train, y_train, x_test, y_test\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dD2MQQu0SGgB"
   },
   "source": [
    "We define the basic training parameters and where we wish to write the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1CO2O1G1SGgC",
    "outputId": "5a00deb7-6aba-4203-aa11-d211a9c022e0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results saved to: ./data/00000-test-train\n"
     ]
    }
   ],
   "source": [
    "outdir = \"./data/\"\n",
    "run_desc = \"test-train\"\n",
    "batch_size = 128\n",
    "num_classes = 10\n",
    "\n",
    "run_dir = generate_output_dir(outdir, run_desc)\n",
    "print(f\"Results saved to: {run_dir}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Kdh4owUnSGgC"
   },
   "source": [
    "Keras provides a prebuilt checkpoint class named **ModelCheckpoint** that contains most of our desired functionality. This built-in class can save the model's state repeatedly as training progresses. Stopping neural network training is not always a controlled event. Sometimes this stoppage can be abrupt, such as a power failure or a network resource shutting down. If Microsoft Windows is your operating system of choice, your training can also be interrupted by a high-priority system update. Because of all of this uncertainty, it is best to save your model at regular intervals. This process is similar to saving a game at critical checkpoints, so you do not have to start over if something terrible happens to your avatar in the game.\n",
    "\n",
    "We will create our checkpoint class, named **MyModelCheckpoint**. In addition to saving the model, we also save the state of the training infrastructure. Why save the training infrastructure, in addition to the weights? This technique eases the transition back into training for the neural network and will be more efficient than a cold start.  \n",
    "\n",
    "Consider if you interrupted your college studies after the first year. Sure, your brain (the neural network) will retain all the knowledge. But how much rework will you have to do? Your transcript at the university is like the training parameters. It ensures you do not have to start over when you come back."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "CjH7sV3NYYws"
   },
   "outputs": [],
   "source": [
    "class MyModelCheckpoint(ModelCheckpoint):\n",
    "  def __init__(self, *args, **kwargs):\n",
    "    super().__init__(*args, **kwargs)\n",
    "\n",
    "  def on_epoch_end(self, epoch, logs=None):\n",
    "    super().on_epoch_end(epoch,logs)\\\n",
    "\n",
    "    # Also save the optimizer state\n",
    "    filepath = self._get_file_path(epoch=epoch, \n",
    "        logs=logs, batch=None)\n",
    "    filepath = filepath.rsplit( \".\", 1 )[ 0 ] \n",
    "    filepath += \".pkl\"\n",
    "\n",
    "    with open(filepath, 'wb') as fp:\n",
    "      pickle.dump(\n",
    "        {\n",
    "          'opt': model.optimizer.get_config(),\n",
    "          'epoch': epoch+1\n",
    "         # Add additional keys if you need to store more values\n",
    "        }, fp, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "    print('\\nEpoch %05d: saving optimizaer to %s' % (epoch + 1, filepath))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "30umhi7bxviZ"
   },
   "source": [
    "The optimizer applies a step decay schedule during training to decrease the learning rate as training progresses.  It is essential to preserve the current epoch that we are on to perform correctly after a training resume."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "8eJKD5S774lD"
   },
   "outputs": [],
   "source": [
    "def step_decay_schedule(initial_lr=1e-3, decay_factor=0.75, step_size=10):\n",
    "    def schedule(epoch):\n",
    "        return initial_lr * (decay_factor ** np.floor(epoch/step_size))\n",
    "    return LearningRateScheduler(schedule)"
   ]
  },
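  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the step decay concrete, the following sketch evaluates the same formula outside of Keras. It assumes the settings that **train_model** passes later in this notebook (initial_lr=1e-4, decay_factor=0.75, step_size=2); the helper name **step_decay** is hypothetical and used only for this illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Hypothetical helper: same formula as schedule() above, without Keras.\n",
    "# Keras passes a zero-based epoch index to the schedule function.\n",
    "def step_decay(epoch, initial_lr=1e-4, decay_factor=0.75, step_size=2):\n",
    "    return initial_lr * decay_factor ** math.floor(epoch / step_size)\n",
    "\n",
    "# Print the learning rate for the first six epochs (displayed as 1 to 6)\n",
    "for epoch in range(6):\n",
    "    print(f'epoch {epoch + 1}: lr = {step_decay(epoch):.4e}')\n",
    "```\n",
    "\n",
    "The rate drops by a factor of 0.75 after every two epochs. A run that restarted its epoch counter at zero would return to the initial rate, which is why the checkpoint also records the epoch."
   ]
  },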
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TJpyn-tbxwQ-"
   },
   "source": [
    "We build the model just as we have in previous sessions.  However, the training function requires a few extra considerations.  We specify the maximum number of epochs; however, we also allow the user to select the starting epoch number for training continuation. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "_TCoWl3APWV2"
   },
   "outputs": [],
   "source": [
    "def build_model(input_shape, num_classes):\n",
    "    model = Sequential()\n",
    "    model.add(Conv2D(32, kernel_size=(3, 3),\n",
    "                     activation='relu',\n",
    "                     input_shape=input_shape))\n",
    "    model.add(Conv2D(64, (3, 3), activation='relu'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "    model.add(Dropout(0.25))\n",
    "    model.add(Flatten())\n",
    "    model.add(Dense(128, activation='relu'))\n",
    "    model.add(Dropout(0.5))\n",
    "    model.add(Dense(num_classes, activation='softmax'))\n",
    "    model.compile(\n",
    "        loss='categorical_crossentropy', \n",
    "        optimizer=tf.keras.optimizers.Adam(),\n",
    "        metrics=['accuracy'])\n",
    "    return model\n",
    "\n",
    "def train_model(model, initial_epoch=0, max_epochs=10):\n",
    "    start_time = time.time()\n",
    "\n",
    "    checkpoint_cb = MyModelCheckpoint(\n",
    "        os.path.join(run_dir, 'model-{epoch:02d}-{val_loss:.2f}.hdf5'),\n",
    "        monitor='val_loss',verbose=1)\n",
    "\n",
    "    lr_sched_cb = step_decay_schedule(initial_lr=1e-4, decay_factor=0.75, \\\n",
    "                                      step_size=2)\n",
    "    cb = [checkpoint_cb, lr_sched_cb]\n",
    "\n",
    "    model.fit(x_train, y_train,\n",
    "              batch_size=batch_size,\n",
    "              epochs=max_epochs,\n",
    "              initial_epoch = initial_epoch,\n",
    "              verbose=2, callbacks=cb,\n",
    "              validation_data=(x_test, y_test))\n",
    "    score = model.evaluate(x_test, y_test, verbose=0, callbacks=cb)\n",
    "    print('Test loss: {}'.format(score[0]))\n",
    "    print('Test accuracy: {}'.format(score[1]))\n",
    "\n",
    "    elapsed_time = time.time() - start_time\n",
    "    print(\"Elapsed time: {}\".format(hms_string(elapsed_time)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dL7pH0bHSGgD"
   },
   "source": [
    "We now begin training, using the **Logger** class to write the output to a log file in the output directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "VD6E_JOySGgD",
    "outputId": "6ceca138-773c-44f6-f46f-176a34fd8b8d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
      "11493376/11490434 [==============================] - 0s 0us/step\n",
      "11501568/11490434 [==============================] - 0s 0us/step\n",
      "Shape of x_train: (60000, 28, 28)\n",
      "Shape of y_train: (60000,)\n",
      "\n",
      "Shape of x_test: (10000, 28, 28)\n",
      "Shape of y_test: (10000,)\n",
      "x_train shape: (60000, 28, 28, 1)\n",
      "Training samples: 60000\n",
      "Test samples: 10000\n",
      "Epoch 1/3\n",
      "\n",
      "Epoch 1: saving model to ./data/00000-test-train/model-01-0.20.hdf5\n",
      "\n",
      "Epoch 00001: saving optimizaer to ./data/00000-test-train/model-01-0.20.pkl\n",
      "469/469 - 12s - loss: 0.6354 - accuracy: 0.8129 - val_loss: 0.1977 - val_accuracy: 0.9420 - lr: 1.0000e-04 - 12s/epoch - 25ms/step\n",
      "Epoch 2/3\n",
      "\n",
      "Epoch 2: saving model to ./data/00000-test-train/model-02-0.11.hdf5\n",
      "\n",
      "Epoch 00002: saving optimizaer to ./data/00000-test-train/model-02-0.11.pkl\n",
      "469/469 - 2s - loss: 0.2284 - accuracy: 0.9332 - val_loss: 0.1087 - val_accuracy: 0.9677 - lr: 1.0000e-04 - 2s/epoch - 5ms/step\n",
      "Epoch 3/3\n",
      "\n",
      "Epoch 3: saving model to ./data/00000-test-train/model-03-0.08.hdf5\n",
      "\n",
      "Epoch 00003: saving optimizaer to ./data/00000-test-train/model-03-0.08.pkl\n",
      "469/469 - 2s - loss: 0.1575 - accuracy: 0.9541 - val_loss: 0.0837 - val_accuracy: 0.9746 - lr: 7.5000e-05 - 2s/epoch - 5ms/step\n",
      "Test loss: 0.08365701138973236\n",
      "Test accuracy: 0.9746000170707703\n",
      "Elapsed time: 0:00:22.09\n"
     ]
    }
   ],
   "source": [
    "with Logger(os.path.join(run_dir, 'log.txt')):\n",
    "    input_shape, x_train, y_train, x_test, y_test = obtain_data()\n",
    "    model = build_model(input_shape, num_classes)\n",
    "    train_model(model, max_epochs=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "p43vw1Rk5f-X"
   },
   "source": [
    "You should notice that the above output displays the name of the hdf5 and pickle (pkl) files produced at each checkpoint. These files serve the following functions:\n",
    "\n",
    "* Pickle files contain the state of the optimizer.\n",
    "* HDF5 files contain the saved model.\n",
    "\n",
    "For this training run, which went for 3 epochs, these two files were named:\n",
    "\n",
    "* ./data/00013-test-train/model-03-0.08.hdf5\n",
    "* ./data/00013-test-train/model-03-0.08.pkl\n",
    "\n",
    "We can inspect the output from the training run. Notice we can see a folder named \"00000-test-train\". This new folder was the first training run. The program will call the next training run \"00001-test-train\", and so on. Inside this directory, you will find the pickle and hdf5 files for each checkpoint.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KSZM0ZVQevwM",
    "outputId": "c81b17d1-809a-48b2-f294-72117e5d8c19"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00000-test-train\n"
     ]
    }
   ],
   "source": [
    "!ls ./data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "YsNFljXksBBy",
    "outputId": "d24873a5-9c37-476b-cc7a-cc0864690732"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log.txt\t\t    model-01-0.20.pkl\tmodel-02-0.11.pkl   model-03-0.08.pkl\n",
      "model-01-0.20.hdf5  model-02-0.11.hdf5\tmodel-03-0.08.hdf5\n"
     ]
    }
   ],
   "source": [
    "!ls ./data/00000-test-train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zQf9LXjmYhgE"
   },
   "source": [
    "Keras stores the model itself in an HDF5, which includes the optimizer. Because of this feature, it is not generally necessary to restore the internal state of the optimizer (such as ADAM). However, we include the code to do so. We can obtain the internal state of an optimizer by calling **get_config**, which will return a dictionary similar to the following:\n",
    "\n",
    "```\n",
    "{'name': 'Adam', 'learning_rate': 7.5e-05, 'decay': 0.0, \n",
    "'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07, 'amsgrad': False}\n",
    "```\n",
    "\n",
    "In practice, I've found that different optimizers implement get_config differently. This function will always return the training hyperparameters. However, it may not always capture the complete internal state of an optimizer beyond the hyperparameters. The exact implementation of get_config can vary per optimizer implementation.\n",
    "\n",
    "## Continuing Training\n",
    "\n",
    "We are now ready to continue training. You will need the paths to both your HDF5 and PKL files. You can find these paths in the output above. Your values may differ from mine, so perform a copy/paste.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "LqYPnKw890n5"
   },
   "outputs": [],
   "source": [
    "MODEL_PATH = './data/00000-test-train/model-03-0.08.hdf5'\n",
    "OPT_PATH = './data/00000-test-train/model-03-0.08.pkl'"
   ]
  },
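  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copying paths by hand becomes tedious across many runs. As an alternative, the sketch below scans a run directory for the checkpoint with the highest epoch number. The helper **latest_checkpoint** is hypothetical (it is not part of Keras) and assumes the file naming pattern used by **train_model** in this notebook.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import re\n",
    "\n",
    "# Matches names such as model-03-0.08.hdf5 (epoch, then validation loss)\n",
    "PATTERN = re.compile('^model-([0-9]+)-([0-9.]+)[.]hdf5$')\n",
    "\n",
    "# Hypothetical helper: return (epoch, model_path, opt_path) for the\n",
    "# checkpoint with the highest epoch number, or None if there is none.\n",
    "def latest_checkpoint(run_dir):\n",
    "    best = None\n",
    "    for name in os.listdir(run_dir):\n",
    "        m = PATTERN.match(name)\n",
    "        if m is None:\n",
    "            continue\n",
    "        epoch = int(m.group(1))\n",
    "        if best is None or epoch > best[0]:\n",
    "            # The pickle file shares the model file name, minus the extension\n",
    "            stem = name.rsplit('.', 1)[0]\n",
    "            best = (epoch,\n",
    "                    os.path.join(run_dir, name),\n",
    "                    os.path.join(run_dir, stem + '.pkl'))\n",
    "    return best\n",
    "```\n",
    "\n",
    "Calling this helper on a run directory returns the epoch number together with the HDF5 and PKL paths, which could then be assigned to MODEL_PATH and OPT_PATH."
   ]
  },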
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_SkIbReAZhif"
   },
   "source": [
    "The following code loads the HDF5 and PKL files and then recompiles the model based on the PKL file.  Depending on the optimizer in use, you might have to recompile the model. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "KnuX1BHZ-BMK"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras.models import load_model\n",
    "import pickle\n",
    "\n",
    "def load_model_data(model_path, opt_path):\n",
    "    model = load_model(model_path)\n",
    "    with open(opt_path, 'rb') as fp:\n",
    "      d = pickle.load(fp)\n",
    "      epoch = d['epoch']\n",
    "      opt = d['opt']\n",
    "      return epoch, model, opt\n",
    "\n",
    "epoch, model, opt = load_model_data(MODEL_PATH, OPT_PATH)\n",
    "\n",
    "# note: often it is not necessary to recompile the model\n",
    "model.compile(\n",
    "    loss='categorical_crossentropy', \n",
    "    optimizer=tf.keras.optimizers.Adam.from_config(opt),\n",
    "    metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "x5VSWcMeZykd"
   },
   "source": [
    "Finally, we train the model for additional epochs.  You can see from the output that the new training starts at a higher accuracy than the first training run.  Further, the accuracy increases with additional training.  Also, you will notice that the epoch number begins at four and not one."
   ]
  },
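  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The epoch numbering follows from how Keras interprets its arguments: **epochs** is the index of the final epoch, not the number of additional epochs, and **initial_epoch** is the zero-based epoch at which to resume. The sketch below works through that arithmetic, assuming the checkpoint recorded epoch 3 and that we pass max_epochs=6 as in the cell that follows; the helper name **resumed_epochs** is hypothetical.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper: list the epoch numbers Keras displays for a resumed run.\n",
    "# fit() trains the zero-based epochs initial_epoch .. max_epochs - 1 and\n",
    "# displays them one-based.\n",
    "def resumed_epochs(initial_epoch, max_epochs):\n",
    "    return [e + 1 for e in range(initial_epoch, max_epochs)]\n",
    "\n",
    "print(resumed_epochs(3, 6))  # epochs shown in the progress output\n",
    "print(resumed_epochs(6, 6))  # empty list: nothing is left to train\n",
    "```\n",
    "\n",
    "Calling **train_model** with max_epochs equal to the saved epoch therefore trains nothing; raise max_epochs to continue training."
   ]
  },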
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "VSNxnP1MurcD",
    "outputId": "6a6504a5-3657-4552-ad82-f90675773eb1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results saved to: ./data/00001-cont-train\n",
      "Shape of x_train: (60000, 28, 28)\n",
      "Shape of y_train: (60000,)\n",
      "\n",
      "Shape of x_test: (10000, 28, 28)\n",
      "Shape of y_test: (10000,)\n",
      "x_train shape: (60000, 28, 28, 1)\n",
      "Training samples: 60000\n",
      "Test samples: 10000\n",
      "Epoch 4/6\n",
      "\n",
      "Epoch 4: saving model to ./data/00001-cont-train/model-04-0.07.hdf5\n",
      "\n",
      "Epoch 00004: saving optimizaer to ./data/00001-cont-train/model-04-0.07.pkl\n",
      "469/469 - 3s - loss: 0.1272 - accuracy: 0.9634 - val_loss: 0.0692 - val_accuracy: 0.9788 - lr: 7.5000e-05 - 3s/epoch - 6ms/step\n",
      "Epoch 5/6\n",
      "\n",
      "Epoch 5: saving model to ./data/00001-cont-train/model-05-0.06.hdf5\n",
      "\n",
      "Epoch 00005: saving optimizaer to ./data/00001-cont-train/model-05-0.06.pkl\n",
      "469/469 - 2s - loss: 0.1099 - accuracy: 0.9677 - val_loss: 0.0612 - val_accuracy: 0.9818 - lr: 5.6250e-05 - 2s/epoch - 5ms/step\n",
      "Epoch 6/6\n",
      "\n",
      "Epoch 6: saving model to ./data/00001-cont-train/model-06-0.06.hdf5\n",
      "\n",
      "Epoch 00006: saving optimizaer to ./data/00001-cont-train/model-06-0.06.pkl\n",
      "469/469 - 2s - loss: 0.0990 - accuracy: 0.9711 - val_loss: 0.0561 - val_accuracy: 0.9827 - lr: 5.6250e-05 - 2s/epoch - 5ms/step\n",
      "Test loss: 0.05610647052526474\n",
      "Test accuracy: 0.982699990272522\n",
      "Elapsed time: 0:00:11.72\n"
     ]
    }
   ],
   "source": [
    "outdir = \"./data/\"\n",
    "run_desc = \"cont-train\"\n",
    "num_classes = 10\n",
    "\n",
    "run_dir = generate_output_dir(outdir, run_desc)\n",
    "print(f\"Results saved to: {run_dir}\")\n",
    "\n",
    "with Logger(os.path.join(run_dir, 'log.txt')):\n",
    "  input_shape, x_train, y_train, x_test, y_test = obtain_data()\n",
    "  train_model(model, initial_epoch=epoch, max_epochs=6)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "anaconda-cloud": {},
  "colab": {
   "collapsed_sections": [],
   "name": "Copy of Copy of t81_558_class_13_02_checkpoint.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3.9 (tensorflow)",
   "language": "python",
   "name": "tensorflow"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
